from functools import partialmethod
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from unicore.modules import (
    LayerNorm,
)
from unicore.utils import (
    permute_final_dims,
)

from .common import Linear


class TriangleMultiplication(nn.Module):
    """Gated multiplicative update of a pair tensor.

    The input ``z`` is layer-normalised and projected into two gated,
    masked tensors ``a`` and ``b`` (``d_hid`` channels each). These are
    combined with a batched matrix product, layer-normalised again and
    projected back to ``d_pair`` channels. A sigmoid gate computed from the
    normalised input modulates the result.

    ``z`` is indexed as ``(..., rows, cols, d_pair)``. The chunked path uses
    ``z.shape[-3]`` as the extent of both pair axes, so the two axes are
    assumed to have equal length (TODO confirm with callers).

    ``outgoing`` selects which of ``a`` / ``b`` is permuted and which is
    transposed before the matrix product, i.e. which pair axis is contracted
    (assuming ``permute_final_dims`` permutes the trailing three dimensions).
    """

    def __init__(self, d_pair, d_hid, outgoing=True):
        super(TriangleMultiplication, self).__init__()
        self.outgoing = outgoing

        # Fused projections: the first d_hid output channels form "a" and the
        # last d_hid form "b" (see torch.chunk in forward and _slice_linear).
        self.linear_ab_p = Linear(d_pair, d_hid * 2)
        self.linear_ab_g = Linear(d_pair, d_hid * 2, init="gating")

        # Output gate and output projection.
        self.linear_g = Linear(d_pair, d_pair, init="gating")
        self.linear_z = Linear(d_hid, d_pair, init="final")

        self.layer_norm_in = LayerNorm(d_pair)
        self.layer_norm_out = LayerNorm(d_hid)

        # When False, the mask is rescaled by 1/sqrt(dim) in _prepare_mask.
        self._alphafold_original_mode = False

    def _prepare_mask(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the mask in the form the projections multiply with.

        Args:
            z: Pair tensor; only its shape, dtype and device are used, and
                only when ``mask`` is None.
            mask: Tensor broadcastable to ``z.shape[:-1]``, or None to keep
                every position (an all-ones mask is created).

        Returns:
            ``mask`` with a trailing singleton dimension appended, multiplied
            by ``mask.shape[-2] ** -0.5`` unless
            ``_alphafold_original_mode`` is set.
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])
        mask = mask.unsqueeze(-1)
        if not self._alphafold_original_mode:
            # scale by 1/sqrt(dim) for numerical stability; the mask is
            # applied to both "a" and "b" before they are multiplied
            mask = mask * (mask.shape[-2] ** -0.5)
        return mask

    def _chunk_2d(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Compute the full gated update tile by tile to bound peak memory.

        Args:
            z: Pair tensor that has already been passed through
                ``layer_norm_in``.
            mask: Mask as produced by ``_prepare_mask`` (trailing singleton
                dimension, already scaled). None means "no masking" and is
                expanded with ``_prepare_mask``.
            chunk_size: Requested tile edge length. Values below 256, and
                None, are raised to 256.

        Returns:
            Tensor with the shape of ``z`` holding
            ``linear_z(layer_norm_out(a @ b)) * sigmoid(linear_g(z))``,
            biases included.
        """
        if mask is None:
            mask = self._prepare_mask(z)

        # avoid too small chunk size
        chunk_size = 256 if chunk_size is None else max(chunk_size, 256)
        new_z = z.new_zeros(z.shape)
        # Used as the extent of BOTH pair axes below.
        dim1 = z.shape[-3]

        def _slice_linear(z, linear: Linear, a=True):
            # Apply only the "a" half (first d_hid rows of the weight) or the
            # "b" half (last d_hid rows) of a fused 2*d_hid projection.
            d_hid = linear.bias.shape[0] // 2
            index = 0 if a else d_hid
            p = (
                nn.functional.linear(z, linear.weight[index : index + d_hid])
                + linear.bias[index : index + d_hid]
            )
            return p

        def _chunk_projection(z, mask, a=True):
            # Masked projection multiplied by its sigmoid gate.
            p = _slice_linear(z, self.linear_ab_p, a) * mask
            p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a))
            return p

        num_chunk = (dim1 + chunk_size - 1) // chunk_size
        for i in range(num_chunk):
            chunk_start = i * chunk_size
            chunk_end = min(chunk_start + chunk_size, dim1)
            # "a" for output rows [chunk_start, chunk_end); computed once per
            # row tile and reused for every column tile.
            if self.outgoing:
                a_chunk = _chunk_projection(
                    z[..., chunk_start:chunk_end, :, :],
                    mask[..., chunk_start:chunk_end, :, :],
                    a=True,
                )
                a_chunk = permute_final_dims(a_chunk, (2, 0, 1))
            else:
                a_chunk = _chunk_projection(
                    z[..., :, chunk_start:chunk_end, :],
                    mask[..., :, chunk_start:chunk_end, :],
                    a=True,
                )
                a_chunk = a_chunk.transpose(-1, -3)

            for j in range(num_chunk):
                j_chunk_start = j * chunk_size
                j_chunk_end = min(j_chunk_start + chunk_size, dim1)
                # "b" for output columns [j_chunk_start, j_chunk_end).
                if self.outgoing:
                    b_chunk = _chunk_projection(
                        z[..., j_chunk_start:j_chunk_end, :, :],
                        mask[..., j_chunk_start:j_chunk_end, :, :],
                        a=False,
                    )
                    b_chunk = b_chunk.transpose(-1, -3)
                else:
                    b_chunk = _chunk_projection(
                        z[..., :, j_chunk_start:j_chunk_end, :],
                        mask[..., :, j_chunk_start:j_chunk_end, :],
                        a=False,
                    )
                    b_chunk = permute_final_dims(b_chunk, (2, 0, 1))
                x_chunk = torch.matmul(a_chunk, b_chunk)
                # Drop the reference early so the tile can be freed.
                del b_chunk
                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
                x_chunk = self.layer_norm_out(x_chunk)
                x_chunk = self.linear_z(x_chunk)
                x_chunk *= torch.sigmoid(
                    self.linear_g(
                        z[..., chunk_start:chunk_end, j_chunk_start:j_chunk_end, :]
                    )
                )
                new_z[
                    ..., chunk_start:chunk_end, j_chunk_start:j_chunk_end, :
                ] = x_chunk
                del x_chunk
            del a_chunk
        return new_z

    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size=None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the multiplicative update.

        Args:
            z: Pair tensor with ``d_pair`` channels in the last dimension.
            mask: Tensor broadcastable to ``z.shape[:-1]``; None keeps every
                position.
            chunk_size: When not None and the module is in eval mode, the
                memory-saving tiled path ``_chunk_2d`` is used.

        Returns:
            Two different contracts, depending on the path taken:

            * tiled path (eval mode with ``chunk_size``): a single tensor,
              the finished update with output bias, gate bias and sigmoid
              gate already applied;
            * otherwise: a tuple ``(x, g)`` where ``x`` is the output
              projection *without* ``linear_z.bias`` and ``g`` is the gate
              pre-activation *without* ``linear_g.bias`` and without the
              sigmoid. The biases are exposed by ``get_output_bias``;
              presumably the caller fuses bias, gate and residual itself.
        """
        mask = self._prepare_mask(z, mask)

        z = self.layer_norm_in(z)
        if not self.training and chunk_size is not None:
            return self._chunk_2d(z, mask, chunk_size=chunk_size)

        # Bias deliberately omitted here; see get_output_bias.
        g = nn.functional.linear(z, self.linear_g.weight)
        ab = self.linear_ab_p(z) * mask
        ab *= torch.sigmoid(self.linear_ab_g(z))
        a, b = torch.chunk(ab, 2, dim=-1)
        # Release references that are no longer needed before the matmul.
        del z, ab

        if self.outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = b.transpose(-1, -3)
        else:
            b = permute_final_dims(b, (2, 0, 1))
            a = a.transpose(-1, -3)
        x = torch.matmul(a, b)
        del a, b

        x = permute_final_dims(x, (1, 2, 0))

        x = self.layer_norm_out(x)
        # Bias deliberately omitted here; see get_output_bias.
        x = nn.functional.linear(x, self.linear_z.weight)
        return x, g

    def get_output_bias(self):
        """Return ``(linear_z.bias, linear_g.bias)``.

        These are the biases left out of the ``(x, g)`` pair returned by the
        non-tiled path of ``forward``.
        """
        return self.linear_z.bias, self.linear_g.bias


class TriangleMultiplicationOutgoing(TriangleMultiplication):
    """``TriangleMultiplication`` whose ``outgoing`` flag defaults to True."""

    def __init__(self, *args, **kwargs):
        # Only fills in the flag when the caller did not pass ``outgoing=``
        # explicitly, so an explicit keyword still takes precedence.
        kwargs.setdefault("outgoing", True)
        TriangleMultiplication.__init__(self, *args, **kwargs)


class TriangleMultiplicationIncoming(TriangleMultiplication):
    """``TriangleMultiplication`` whose ``outgoing`` flag defaults to False."""

    def __init__(self, *args, **kwargs):
        # Only fills in the flag when the caller did not pass ``outgoing=``
        # explicitly, so an explicit keyword still takes precedence.
        kwargs.setdefault("outgoing", False)
        TriangleMultiplication.__init__(self, *args, **kwargs)
